% Main Simulation: iid case, figure plot for one typical replication

% Start from a clean session: wipe the workspace, silence all warnings,
% and clear the command window.
clear('all')
warning('off')
clc


% 0. parameter setup
% The s-th simulation design pairs TT(s), NN(s) and KK(s).
TT = [50, 100, 200];        % sample sizes passed to the estimators
NN = [100, 200, 300];       % numbers of columns in the simulated X
KK = [2, 4, 6];             % leading block of Phi0 used by each design
T0 = 1;                     % extra rows generated after the first T
sigu = 5;                   % scale of the noise added to X
sigy = 1;                   % scale of the noise added to y
% 9-by-9 tridiagonal matrix: diagonal 1, 1.5, ..., 5 and 0.1 on the first
% super- and sub-diagonals.
Phi0 = diag(1:0.5:5) + 0.1*(diag(ones(1,8),1) + diag(ones(1,8),-1));
L = [5, 10, 20];            % not referenced again in this script
TUN = 0.1:0.1:1;            % tuning-parameter grid handed to est_cv
 
% 1. main estimation
% One replication per design. For design s the script stores, side by
% side in w{s}, the weight vector returned by DGP_gen followed by the
% three est_cv estimates.
tic
nDesign = length(KK);
w = cell(nDesign,1);
% est_cv selector codes, in call order. Per the labels used in this
% script: 5 = L2RELAX nonlinear shrinkage, 1 = LASSO, 2 = RIDGE.
methodCode = [5 1 2];
for s = 1:nDesign
    T = TT(s);
    N = NN(s);
    K = KK(s);
    % 2. ready DGP: generate T+T0 rows, using the leading K-by-K block of Phi0
    [y0,X0,~,w0] = DGP_gen(N,T+T0,Phi0(1:K,1:K),sigu,sigy);
    inSample = 1:T;
    Xt = X0(inSample,:);
    yt = y0(inSample,:);
    % Remaining T0 rows; they are split off here but not used further below.
    Xe = X0(T+1:end,:);
    ye = y0(T+1:end,:);
    % 3. estimate weights for figure 1 (same call order as the selector list)
    west = cell(1,numel(methodCode));
    for m = 1:numel(methodCode)
        west{m} = est_cv(yt,Xt,TUN,methodCode(m));
    end
    w{s} = [w0 west{:}];
end
toc

% 4. plot figures
% 4-by-3 grid of panels: row i is the estimator (column i of w{s}),
% column s is the design. Panels are lettered (a), (b), ... row by row.
figure(1)
EST = {'Oracle','L2Relax','LASSO','Ridge'};
% y-axis range for each design column; row s applies to design s.
yRange = [-0.01  0.03;
          -0.01  0.02;
          -0.005 0.015];
for s = 1:length(w)
    wt = w{s};
    for i = 1:4
        panel = (i-1)*3+s;              % linear subplot index, row-major
        subplot(4,3,panel)
        plot(wt(:,i),'k')
        if s <= size(yRange,1)          % only the first three designs have a preset range
            ylim(yRange(s,:))
        end
        ylabel('Weight')
        tag = ['(' char(panel+96) ') '];    % 96+1 = 'a', so panel 1 -> '(a) '
        title([tag EST{i} ', K=' num2str(KK(s))])
    end
end

% % 5. plot the convergence paths in the appendix
% TUN2 = 0:0.1:5;
% TUN3 = 0:1:50;
% T = TT(1);
% N = NN(1);    
% K = KK(1);     
% path0 = zeros(N,length(TUN2));
% path1 = zeros(N,length(TUN3));
% [y0,X0,~,w0] = DGP_gen(N,T+T0,Phi0(1:K,1:K),sigu,sigy);           
% Xt = X0(1:T,:);
% yt = y0(1:T,:);
% parfor i = 1:length(TUN2)
%     path0(:,i) = l2relax0(yt,Xt,TUN2(i),1);        
% end  
% parfor i = 1:length(TUN3)    
%     path1(:,i) = ridge_est(yt,Xt,TUN3(i),1);    
% end  
% figure (2)
% P = {path0,path1};
% TUN = {TUN2,TUN3};
% VN = {'L2Relax','Ridge'};
% for i = 1:length(P)
%     subplot(1,2,i)
%     plot(TUN{i},P{i})
%     xlabel('\tau')
%     ylabel('Weight')
%     ylim([-0.04 0.06])
%     title(['(' char(i+96) ') ' VN{i}])
% end

%% Local Functions
function [y,x,R2,w0,SUB] = DGP_gen(N,T,Phi0,sigu,sigy)
% DGP_GEN  Draw one simulated sample (y, x) together with population quantities.
%
%   Inputs
%     N     number of columns of x; N/size(Phi0,1) must be a whole number
%           because it is used as a block size
%     T     number of rows to generate
%     Phi0  K-by-K matrix that is expanded block-wise to N-by-N
%     sigu  scale of the noise added to x
%     sigy  scale of the noise added to y
%
%   Outputs
%     y     T-by-1 response
%     x     T-by-N regressors
%     R2    1 - SUB/(w0'*Phi*w0 + sigy^2)
%     w0    N-by-1 weight vector, constant within each block of N/K entries
%     SUB   w0'*(Ou - Ou*inv(Phi+Ou)*Ou)*w0 + sigy^2, with Ou = sigu^2*I

    % 0. basic parameters
    K = size(Phi0,1);
    Nk = N/K;                                   % block size
    Phi = kron(Phi0,ones(Nk));                  % every entry of Phi0 becomes an Nk-by-Nk constant block
    invPhi0 = Phi0^(-1);
    iotaK = ones(K,1);
    % Normalise inv(Phi0)*1 so it sums to one, then spread each entry
    % evenly over the Nk members of its block.
    w0 = kron(invPhi0*iotaK/(iotaK'*invPhi0*iotaK),ones(Nk,1)/Nk);
    Ou = sigu^2*eye(N);     % Omega_u
    noiseVar = sigy^2;
    SUB = w0'*(Ou-Ou*(Phi+Ou)^(-1)*Ou)*w0 + noiseVar;
    R2 = 1 - SUB/(w0'*Phi*w0 + noiseVar);

    % 1. generate DGP (draw order: eta, then u, then uy)
    eta = randn(T,N);
    u = randn(T,N)*sigu;
    uy = randn(T,1)*sigy;
    % NOTE(review): sqrt(Phi) is the element-wise square root, not the
    % matrix square root (sqrtm). SUB and R2 above are computed from Phi
    % itself -- confirm that the element-wise root is the intended loading.
    common = eta*sqrt(Phi);
    x = common + u;
    y = common*w0 + uy;
end

function w = ridge_est(ytrain,XTRAIN,tun,opt)
% RIDGE_EST  Ridge-type weight estimate, with or without a sum-to-one constraint.
%
%   Inputs
%     ytrain  response vector
%     XTRAIN  regressor matrix; one weight is estimated per column
%     tun     ridge tuning parameter
%     opt     0: plain ridge regression via RIDGE (no constraint on the weights)
%             1: minimise RIDGE_FUN subject to sum(w) == 1 via FMINCON
%
%   Output
%     w       estimated weight vector
%
%   Any other value of opt raises 'ridge_est:invalidOption'. Without this
%   check neither branch would run and w would be left unassigned, so the
%   call would fail with a generic unassigned-output error.

    if opt == 0         % no sum-to-1 condition
        w = ridge(ytrain,XTRAIN,tun);
    elseif opt == 1     % sum-to-1 condition
        % NOTE: this turns off all warnings for the whole session and does
        % not restore them afterwards; kept as-is so existing callers see
        % no change.
        warning off
        N = size(XTRAIN,2);
        options = optimset('display','none');
        w0 = ones(N,1)/N;                       % start from equal weights
        % Aeq = ones(1,N), beq = 1 imposes sum(w) == 1; no inequality
        % constraints or bounds are set.
        w = fmincon(@(w) ridge_fun(w,XTRAIN,ytrain,tun),w0,[],[],ones(1,N),1,[],[],[],options);
    else
        error('ridge_est:invalidOption', ...
              'opt must be 0 (no sum-to-one condition) or 1 (sum-to-one condition).');
    end
end

function cri = ridge_fun(w,X,y,tun)
% RIDGE_FUN  Objective minimised by ridge_est under the sum-to-one option.
%
%   cri = analytical_shrinkage(y - X*w)/2 + tun*||w - 1/N||^2,
%   where N is the number of columns of X, so the penalty pulls the
%   weights towards equal weights 1/N rather than towards zero.
%
%   NOTE(review): analytical_shrinkage is defined outside this file; it is
%   presumably a shrinkage estimate of the residual variance -- confirm.

    equalWeight = 1/size(X,2);
    dev = w - equalWeight;          % deviation from equal weights
    resid = y - X*w;
    cri = analytical_shrinkage(resid)/2 + tun*(dev'*dev);
end

